{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🧱 DCGAN - Bricks Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1603ea4b-8345-4e2e-ae7c-01c9953900e8",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own DCGAN on the bricks dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e0d56cc-4773-4029-97d8-26f882ba79c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import (\n",
    "    layers,\n",
    "    models,\n",
    "    callbacks,\n",
    "    losses,\n",
    "    utils,\n",
    "    metrics,\n",
    "    optimizers,\n",
    ")\n",
    "\n",
    "from notebooks.utils import display, sample_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SIZE = 64\n",
    "CHANNELS = 1\n",
    "BATCH_SIZE = 128\n",
    "Z_DIM = 100\n",
    "EPOCHS = 300\n",
    "LOAD_MODEL = False\n",
    "ADAM_BETA_1 = 0.5\n",
    "ADAM_BETA_2 = 0.999\n",
    "LEARNING_RATE = 0.0002\n",
    "NOISE_PARAM = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7716fac-0010-49b0-b98e-53be2259edde",
   "metadata": {},
   "source": [
    "## 1. Prepare the data <a name=\"prepare\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73f4c594-3f6d-4c8e-94c1-2c2ba7bce076",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = utils.image_dataset_from_directory(\n",
    "    \"/app/data/lego-brick-images/dataset/\",\n",
    "    labels=None,\n",
    "    color_mode=\"grayscale\",\n",
    "    image_size=(IMAGE_SIZE, IMAGE_SIZE),\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    seed=42,\n",
    "    interpolation=\"bilinear\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a995473f-c389-4158-92d2-93a2fa937916",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(img):\n",
    "    \"\"\"\n",
    "    Normalize and reshape the images\n",
    "    \"\"\"\n",
    "    img = (tf.cast(img, \"float32\") - 127.5) / 127.5\n",
    "    return img\n",
    "\n",
    "\n",
    "train = train_data.map(lambda x: preprocess(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80bcdbdd-fb1e-451f-b89c-03fd9b80deb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_sample = sample_batch(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa53709f-7f3f-483b-9db8-2e5f9b9942c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(train_sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aff50401-3abe-4c10-bba8-b35bc13ad7d5",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 2. Build the GAN <a name=\"build\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9230b5bf-b4a8-48d5-b73b-6899a598f296",
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))\n",
    "x = layers.Conv2D(64, kernel_size=4, strides=2, padding=\"same\", use_bias=False)(\n",
    "    discriminator_input\n",
    ")\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = layers.Conv2D(\n",
    "    128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = layers.Conv2D(\n",
    "    256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = layers.Conv2D(\n",
    "    512, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = layers.Conv2D(\n",
    "    1,\n",
    "    kernel_size=4,\n",
    "    strides=1,\n",
    "    padding=\"valid\",\n",
    "    use_bias=False,\n",
    "    activation=\"sigmoid\",\n",
    ")(x)\n",
    "discriminator_output = layers.Flatten()(x)\n",
    "\n",
    "discriminator = models.Model(discriminator_input, discriminator_output)\n",
    "discriminator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b30dcc08-3869-4b67-a295-61f13d5d4e4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_input = layers.Input(shape=(Z_DIM,))\n",
    "x = layers.Reshape((1, 1, Z_DIM))(generator_input)\n",
    "x = layers.Conv2DTranspose(\n",
    "    512, kernel_size=4, strides=1, padding=\"valid\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    256, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    128, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "x = layers.Conv2DTranspose(\n",
    "    64, kernel_size=4, strides=2, padding=\"same\", use_bias=False\n",
    ")(x)\n",
    "x = layers.BatchNormalization(momentum=0.9)(x)\n",
    "x = layers.LeakyReLU(0.2)(x)\n",
    "generator_output = layers.Conv2DTranspose(\n",
    "    CHANNELS,\n",
    "    kernel_size=4,\n",
    "    strides=2,\n",
    "    padding=\"same\",\n",
    "    use_bias=False,\n",
    "    activation=\"tanh\",\n",
    ")(x)\n",
    "generator = models.Model(generator_input, generator_output)\n",
    "generator.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed493725-488b-4390-8c64-661f3b97a632",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class DCGAN(models.Model):\n",
    "    def __init__(self, discriminator, generator, latent_dim):\n",
    "        super(DCGAN, self).__init__()\n",
    "        self.discriminator = discriminator\n",
    "        self.generator = generator\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "    def compile(self, d_optimizer, g_optimizer):\n",
    "        super(DCGAN, self).compile()\n",
    "        self.loss_fn = losses.BinaryCrossentropy()\n",
    "        self.d_optimizer = d_optimizer\n",
    "        self.g_optimizer = g_optimizer\n",
    "        self.d_loss_metric = metrics.Mean(name=\"d_loss\")\n",
    "        self.d_real_acc_metric = metrics.BinaryAccuracy(name=\"d_real_acc\")\n",
    "        self.d_fake_acc_metric = metrics.BinaryAccuracy(name=\"d_fake_acc\")\n",
    "        self.d_acc_metric = metrics.BinaryAccuracy(name=\"d_acc\")\n",
    "        self.g_loss_metric = metrics.Mean(name=\"g_loss\")\n",
    "        self.g_acc_metric = metrics.BinaryAccuracy(name=\"g_acc\")\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        return [\n",
    "            self.d_loss_metric,\n",
    "            self.d_real_acc_metric,\n",
    "            self.d_fake_acc_metric,\n",
    "            self.d_acc_metric,\n",
    "            self.g_loss_metric,\n",
    "            self.g_acc_metric,\n",
    "        ]\n",
    "\n",
    "    def train_step(self, real_images):\n",
    "        # Sample random points in the latent space\n",
    "        batch_size = tf.shape(real_images)[0]\n",
    "        random_latent_vectors = tf.random.normal(\n",
    "            shape=(batch_size, self.latent_dim)\n",
    "        )\n",
    "\n",
    "        # Train the discriminator on fake images\n",
    "        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
    "            generated_images = self.generator(\n",
    "                random_latent_vectors, training=True\n",
    "            )\n",
    "            real_predictions = self.discriminator(real_images, training=True)\n",
    "            fake_predictions = self.discriminator(\n",
    "                generated_images, training=True\n",
    "            )\n",
    "\n",
    "            real_labels = tf.ones_like(real_predictions)\n",
    "            real_noisy_labels = real_labels + NOISE_PARAM * tf.random.uniform(\n",
    "                tf.shape(real_predictions)\n",
    "            )\n",
    "            fake_labels = tf.zeros_like(fake_predictions)\n",
    "            fake_noisy_labels = fake_labels - NOISE_PARAM * tf.random.uniform(\n",
    "                tf.shape(fake_predictions)\n",
    "            )\n",
    "\n",
    "            d_real_loss = self.loss_fn(real_noisy_labels, real_predictions)\n",
    "            d_fake_loss = self.loss_fn(fake_noisy_labels, fake_predictions)\n",
    "            d_loss = (d_real_loss + d_fake_loss) / 2.0\n",
    "\n",
    "            g_loss = self.loss_fn(real_labels, fake_predictions)\n",
    "\n",
    "        gradients_of_discriminator = disc_tape.gradient(\n",
    "            d_loss, self.discriminator.trainable_variables\n",
    "        )\n",
    "        gradients_of_generator = gen_tape.gradient(\n",
    "            g_loss, self.generator.trainable_variables\n",
    "        )\n",
    "\n",
    "        self.d_optimizer.apply_gradients(\n",
    "            zip(gradients_of_discriminator, discriminator.trainable_variables)\n",
    "        )\n",
    "        self.g_optimizer.apply_gradients(\n",
    "            zip(gradients_of_generator, generator.trainable_variables)\n",
    "        )\n",
    "\n",
    "        # Update metrics\n",
    "        self.d_loss_metric.update_state(d_loss)\n",
    "        self.d_real_acc_metric.update_state(real_labels, real_predictions)\n",
    "        self.d_fake_acc_metric.update_state(fake_labels, fake_predictions)\n",
    "        self.d_acc_metric.update_state(\n",
    "            [real_labels, fake_labels], [real_predictions, fake_predictions]\n",
    "        )\n",
    "        self.g_loss_metric.update_state(g_loss)\n",
    "        self.g_acc_metric.update_state(real_labels, fake_predictions)\n",
    "\n",
    "        return {m.name: m.result() for m in self.metrics}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e898dd8e-f562-4517-8351-fc2f8b617a24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a DCGAN\n",
    "dcgan = DCGAN(\n",
    "    discriminator=discriminator, generator=generator, latent_dim=Z_DIM\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "800a3c6e-fb11-4792-b6bc-9a43a7c977ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "if LOAD_MODEL:\n",
    "    dcgan.load_weights(\"./checkpoint/checkpoint.ckpt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35b14665-4359-447b-be58-3fd58ba69084",
   "metadata": {},
   "source": [
    "## 3. Train the GAN <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "245e6374-5f5b-4efa-be0a-07b182f82d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dcgan.compile(\n",
    "    d_optimizer=optimizers.Adam(\n",
    "        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
    "    ),\n",
    "    g_optimizer=optimizers.Adam(\n",
    "        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "349865fe-ffbe-450e-97be-043ae1740e78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a model save checkpoint\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint/checkpoint.ckpt\",\n",
    "    save_weights_only=True,\n",
    "    save_freq=\"epoch\",\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "\n",
    "class ImageGenerator(callbacks.Callback):\n",
    "    def __init__(self, num_img, latent_dim):\n",
    "        self.num_img = num_img\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        random_latent_vectors = tf.random.normal(\n",
    "            shape=(self.num_img, self.latent_dim)\n",
    "        )\n",
    "        generated_images = self.model.generator(random_latent_vectors)\n",
    "        generated_images = generated_images * 127.5 + 127.5\n",
    "        generated_images = generated_images.numpy()\n",
    "        display(\n",
    "            generated_images,\n",
    "            save_to=\"./output/generated_img_%03d.png\" % (epoch),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8913a77-f472-4008-9039-dba00e6db980",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "dcgan.fit(\n",
    "    train,\n",
    "    epochs=EPOCHS,\n",
    "    callbacks=[\n",
    "        model_checkpoint_callback,\n",
    "        tensorboard_callback,\n",
    "        ImageGenerator(num_img=10, latent_dim=Z_DIM),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369bde44-2e39-4bc6-8549-a3a27ecce55c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the final models\n",
    "generator.save(\"./models/generator\")\n",
    "discriminator.save(\"./models/discriminator\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26999087-0e85-4ddf-ba5f-13036466fce7",
   "metadata": {},
   "source": [
    "## 3. Generate new images <a name=\"decode\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48e90117-2e0e-4f4b-9138-b25dce9870f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample some points in the latent space, from the standard normal distribution\n",
    "grid_width, grid_height = (10, 3)\n",
    "z_sample = np.random.normal(size=(grid_width * grid_height, Z_DIM))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e185509-3861-425c-882d-4fe16d82d355",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the sampled points\n",
    "reconstructions = generator.predict(z_sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e5e43c0-ef06-4d32-acf6-09f00cf2fa9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a plot of decoded images\n",
    "fig = plt.figure(figsize=(18, 5))\n",
    "fig.subplots_adjust(hspace=0.4, wspace=0.4)\n",
    "\n",
    "# Output the grid of faces\n",
    "for i in range(grid_width * grid_height):\n",
    "    ax = fig.add_subplot(grid_height, grid_width, i + 1)\n",
    "    ax.axis(\"off\")\n",
    "    ax.imshow(reconstructions[i, :, :], cmap=\"Greys\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59bfd4e4-7fdc-488a-86df-2c131c904803",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_images(img1, img2):\n",
    "    return np.mean(np.abs(img1 - img2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b568995a-d4ad-478c-98b2-d9a1cdb9e841",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = []\n",
    "for i in train.as_numpy_iterator():\n",
    "    all_data.extend(i)\n",
    "all_data = np.array(all_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c4b5bb1-3581-49b3-81ce-920400d6f3f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "r, c = 3, 5\n",
    "fig, axs = plt.subplots(r, c, figsize=(10, 6))\n",
    "fig.suptitle(\"Generated images\", fontsize=20)\n",
    "\n",
    "noise = np.random.normal(size=(r * c, Z_DIM))\n",
    "gen_imgs = generator.predict(noise)\n",
    "\n",
    "cnt = 0\n",
    "for i in range(r):\n",
    "    for j in range(c):\n",
    "        axs[i, j].imshow(gen_imgs[cnt], cmap=\"gray_r\")\n",
    "        axs[i, j].axis(\"off\")\n",
    "        cnt += 1\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51923e98-bf0e-4de4-948a-05147c486b72",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(r, c, figsize=(10, 6))\n",
    "fig.suptitle(\"Closest images in the training set\", fontsize=20)\n",
    "\n",
    "cnt = 0\n",
    "for i in range(r):\n",
    "    for j in range(c):\n",
    "        c_diff = 99999\n",
    "        c_img = None\n",
    "        for k_idx, k in enumerate(all_data):\n",
    "            diff = compare_images(gen_imgs[cnt], k)\n",
    "            if diff < c_diff:\n",
    "                c_img = np.copy(k)\n",
    "                c_diff = diff\n",
    "        axs[i, j].imshow(c_img, cmap=\"gray_r\")\n",
    "        axs[i, j].axis(\"off\")\n",
    "        cnt += 1\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
